
package com.shiku.imserver.hander;

import com.shiku.imserver.common.http.HttpRequest;
import com.shiku.imserver.common.http.HttpRequestDecoder;
import com.shiku.imserver.common.http.HttpResponse;
import com.shiku.imserver.common.http.HttpResponseEncoder;
import com.shiku.imserver.common.message.PacketVO;
import com.shiku.imserver.common.packets.ImPacket;
import com.shiku.imserver.common.ws.IWsMsgHandler;
import com.shiku.imserver.common.ws.WsPacket;
import com.shiku.imserver.common.ws.WsRequest;
import com.shiku.imserver.common.ws.WsResponse;
import com.shiku.imserver.common.ws.WsServerConfig;
import com.shiku.imserver.common.ws.WsServerDecoder;
import com.shiku.imserver.common.ws.WsServerEncoder;
import com.shiku.imserver.common.ws.WsSessionContext;
import com.shiku.imserver.service.IMBeanUtils;

import java.io.UnsupportedEncodingException;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tio.core.ChannelContext;
import org.tio.core.GroupContext;
import org.tio.core.Tio;
import org.tio.core.exception.AioDecodeException;
import org.tio.core.intf.Packet;
import org.tio.http.common.HeaderName;
import org.tio.http.common.HeaderValue;
import org.tio.http.common.HttpResponseStatus;
import org.tio.http.common.HeaderValue.Connection;
import org.tio.http.common.HeaderValue.Upgrade;
import org.tio.utils.hutool.StrUtil;
import org.tio.websocket.common.Opcode;
import org.tio.websocket.common.util.BASE64Util;
import org.tio.websocket.common.util.SHA1Util;

/**
 * tio protocol handler for WebSocket connections.
 *
 * <p>While the connection's {@link WsSessionContext} is not yet marked as handshaked, incoming
 * bytes are decoded as an HTTP upgrade request and answered with a {@code 101 Switching Protocols}
 * response. Afterwards, incoming bytes are decoded as WebSocket frames. Decoded packets are
 * dispatched to the IM message processor obtained from {@link IMBeanUtils#getMessageProcess()}.
 */
public class WsServerAioHandler extends AbstractProtocolHandler {
    private static final Logger log = LoggerFactory.getLogger(WsServerAioHandler.class);

    /** GUID appended to the client's Sec-WebSocket-Key to compute Sec-WebSocket-Accept (RFC 6455). */
    private static final String WEBSOCKET_ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /**
     * Command set on the synthetic request returned once the handshake completes.
     * NOTE(review): presumably the dispatcher's "handshake done" command — confirm.
     */
    private static final short CMD_HANDSHAKE_REQUEST = 1;

    /**
     * Command set on the handshake HTTP response. {@link #encode} uses it to recognise that packet
     * and serialise it as HTTP rather than as a WebSocket frame.
     */
    private static final short CMD_HANDSHAKE_RESPONSE = 2;

    /** Command set on decoded CLOSE frames. NOTE(review): presumably the dispatcher's close command — confirm. */
    private static final short CMD_CLOSE = 7;

    private WsServerConfig wsServerConfig;
    private IWsMsgHandler wsMsgHandler;

    /**
     * @param wsServerConfig server configuration (charset, HTTP decoding settings)
     * @param wsMsgHandler   callback used by the internal frame handler for text/binary/close frames
     */
    public WsServerAioHandler(WsServerConfig wsServerConfig, IWsMsgHandler wsMsgHandler) {
        this.wsServerConfig = wsServerConfig;
        this.wsMsgHandler = wsMsgHandler;
    }

    /**
     * Decodes the next packet from {@code buffer}.
     *
     * <p>Before the handshake, parses an HTTP upgrade request. On success it:
     * <ul>
     *   <li>stores the request and the {@code 101} response in the session;</li>
     *   <li>marks the session handshaked;</li>
     *   <li>sends the response;</li>
     *   <li>returns a synthetic {@link WsRequest} flagged as a handshake.</li>
     * </ul>
     *
     * <p>After the handshake, decodes a WebSocket frame. CLOSE frames are tagged with the close
     * command.
     *
     * @return the decoded packet, or {@code null} when the underlying decoder returned {@code null}
     *     (in tio this means more bytes are needed)
     * @throws AioDecodeException if the HTTP request lacks a usable Sec-WebSocket-Key
     */
    @Override
    public WsRequest decode(ByteBuffer buffer, int limit, int position, int readableLength, ChannelContext channelContext) throws AioDecodeException {
        WsSessionContext wsSessionContext = (WsSessionContext) channelContext.getAttribute();

        if (wsSessionContext.isHandshaked()) {
            WsRequest websocketPacket = WsServerDecoder.decode(buffer, channelContext);
            if (websocketPacket != null) {
                websocketPacket.setWSBytes(websocketPacket.getBytes());
                if (websocketPacket.getWsOpcode() == Opcode.CLOSE) {
                    websocketPacket.setCommand(CMD_CLOSE);
                }
            }
            return websocketPacket;
        }

        HttpRequest request = HttpRequestDecoder.decode(buffer, limit, position, readableLength, channelContext, this.wsServerConfig);
        if (request == null) {
            return null;
        }

        HttpResponse httpResponse = updateWebSocketProtocol(request, channelContext);
        if (httpResponse == null) {
            throw new AioDecodeException("http协议升级到websocket协议失败");
        }

        // Tag the response so encode() serialises it as HTTP, not as a WebSocket frame.
        httpResponse.setCommand(CMD_HANDSHAKE_RESPONSE);
        wsSessionContext.setHandshakeRequest(request);
        wsSessionContext.setHandshakeResponse(httpResponse);
        wsSessionContext.setHandshaked(true);
        Tio.send(channelContext, httpResponse);

        WsRequest handshakePacket = new WsRequest();
        handshakePacket.setHandShake(true);
        handshakePacket.setCommand(CMD_HANDSHAKE_REQUEST);
        return handshakePacket;
    }

    /** This overload is not used for decoding; it always returns {@code null}. */
    @Override
    public Packet decode(ByteBuffer buffer, ChannelContext channelContext) throws AioDecodeException {
        return null;
    }

    /**
     * Encodes an outgoing packet.
     *
     * <p>A packet carrying the handshake-response command is encoded as the HTTP response stored in
     * the session. Every other packet is wrapped in a {@link WsPacket} and encoded as a WebSocket
     * frame.
     *
     * @return the encoded bytes, or {@code null} if HTTP encoding fails with an unsupported charset
     */
    @Override
    public ByteBuffer encode(Packet packet, GroupContext groupContext, ChannelContext channelContext) {
        ImPacket imPacket = (ImPacket) packet;
        if (imPacket.getCommand() != CMD_HANDSHAKE_RESPONSE) {
            WsPacket wsPacket = new WsPacket(imPacket.getCommand(), imPacket.getBytes());
            return WsServerEncoder.encode(wsPacket, groupContext, channelContext);
        }

        WsSessionContext wsSessionContext = (WsSessionContext) channelContext.getAttribute();
        HttpResponse handshakeResponse = wsSessionContext.getHandshakeResponse();
        try {
            return HttpResponseEncoder.encode(handshakeResponse, groupContext, channelContext);
        } catch (UnsupportedEncodingException e) {
            log.error(e.toString(), e);
            return null;
        }
    }

    /** Returns the WebSocket server configuration (named "http" config for historical reasons). */
    public WsServerConfig getHttpConfig() {
        return this.wsServerConfig;
    }

    /**
     * Dispatches a decoded packet to the IM message processor.
     *
     * <p>If the processor returns a result, that result is sent back on the same channel via
     * {@code Tio.bSend} (presumably tio's blocking send).
     */
    @Override
    public void handler(Packet packet, ChannelContext channelContext) throws Exception {
        ImPacket impacket = (ImPacket) packet;
        PacketVO result = IMBeanUtils.getMessageProcess().dispatch(impacket, channelContext);
        if (result != null) {
            Tio.bSend(channelContext, new WsPacket(result.getCmd(), result.getBytes()));
        }
    }

    /**
     * Routes a WebSocket frame to the matching {@link IWsMsgHandler} callback and converts its
     * return value into a response.
     *
     * <p>Behaviour by frame type:
     * <ul>
     *   <li>TEXT or BINARY frames with an empty body close the channel.</li>
     *   <li>PING and PONG frames are only logged.</li>
     *   <li>Unknown opcodes close the channel.</li>
     * </ul>
     *
     * <p>NOTE(review): this private method is not called anywhere in this class.
     *
     * @return the response to send, or {@code null} when there is nothing to send
     */
    private WsResponse h(WsRequest websocketPacket, byte[] bytes, Opcode opcode, ChannelContext channelContext) throws Exception {
        if (opcode == Opcode.TEXT) {
            if (bytes == null || bytes.length == 0) {
                Tio.remove(channelContext, "错误的websocket包，body为空");
                return null;
            }
            String text = new String(bytes, this.wsServerConfig.getCharset());
            Object retObj = this.wsMsgHandler.onText(websocketPacket, text, channelContext);
            return this.processRetObj(retObj, "onText", channelContext);
        }

        if (opcode == Opcode.BINARY) {
            if (bytes == null || bytes.length == 0) {
                Tio.remove(channelContext, "错误的websocket包，body为空");
                return null;
            }
            Object retObj = this.wsMsgHandler.onBytes(websocketPacket, bytes, channelContext);
            return this.processRetObj(retObj, "onBytes", channelContext);
        }

        if (opcode == Opcode.PING || opcode == Opcode.PONG) {
            log.debug("收到{}", opcode);
            return null;
        }

        if (opcode == Opcode.CLOSE) {
            Object retObj = this.wsMsgHandler.onClose(websocketPacket, bytes, channelContext);
            return this.processRetObj(retObj, "onClose", channelContext);
        }

        Tio.remove(channelContext, "错误的websocket包，错误的Opcode");
        return null;
    }

    /**
     * Converts a handler callback's return value into a {@link WsResponse}.
     *
     * <p>Conversion by return type:
     * <ul>
     *   <li>{@code String}: encoded with the configured charset.</li>
     *   <li>{@code byte[]}: used as-is.</li>
     *   <li>{@code ByteBuffer}: its readable bytes ({@code position..limit}) are copied.</li>
     *   <li>{@code WsResponse}: returned unchanged.</li>
     *   <li>Any other type: logged as an error.</li>
     * </ul>
     *
     * @return the response, or {@code null} for a {@code null} or unsupported value
     */
    private WsResponse processRetObj(Object obj, String methodName, ChannelContext channelContext) throws Exception {
        if (obj == null) {
            return null;
        }
        if (obj instanceof String) {
            return WsResponse.fromText((String) obj, this.wsServerConfig.getCharset());
        }
        if (obj instanceof byte[]) {
            return WsResponse.fromBytes((byte[]) obj);
        }
        if (obj instanceof WsResponse) {
            return (WsResponse) obj;
        }
        if (obj instanceof ByteBuffer) {
            // ByteBuffer.array() would expose the whole backing array (ignoring position/limit/offset)
            // and throws for direct or read-only buffers, so copy only the readable region.
            return WsResponse.fromBytes(remainingBytes((ByteBuffer) obj));
        }
        // The offending method belongs to the message handler, so report that class.
        log.error("{} {}.{}()方法，只允许返回String、byte[]、ByteBuffer、WsResponse或null，但是程序返回了{}",
                new Object[]{channelContext, this.wsMsgHandler.getClass().getName(), methodName, obj.getClass().getName()});
        return null;
    }

    /** Copies the readable bytes ({@code position..limit}) of a buffer without moving its position. */
    private static byte[] remainingBytes(ByteBuffer buffer) {
        ByteBuffer view = buffer.duplicate();
        byte[] bytes = new byte[view.remaining()];
        view.get(bytes);
        return bytes;
    }

    /** Replaces the WebSocket server configuration (named "http" config for historical reasons). */
    public void setHttpConfig(WsServerConfig httpConfig) {
        this.wsServerConfig = httpConfig;
    }

    /**
     * Builds the {@code 101 Switching Protocols} response for a WebSocket upgrade request.
     *
     * <p>Sec-WebSocket-Accept is computed as Base64(SHA-1(key + GUID)), as specified by RFC 6455.
     *
     * @param request        the decoded HTTP upgrade request
     * @param channelContext unused; kept for API compatibility
     * @return the upgrade response, or {@code null} if the request has no non-blank
     *     {@code sec-websocket-key} header
     */
    public static HttpResponse updateWebSocketProtocol(HttpRequest request, ChannelContext channelContext) {
        Map<String, String> headers = request.getHeaders();
        String secWebSocketKey = headers == null ? null : headers.get("sec-websocket-key");
        if (!StrUtil.isNotBlank(secWebSocketKey)) {
            return null;
        }

        byte[] digest = SHA1Util.SHA1(secWebSocketKey + WEBSOCKET_ACCEPT_GUID);
        String acceptKey = BASE64Util.byteArrayToBase64(digest);

        HttpResponse httpResponse = new HttpResponse(request);
        httpResponse.setStatus(HttpResponseStatus.C101);
        Map<HeaderName, HeaderValue> respHeaders = new HashMap<>();
        respHeaders.put(HeaderName.Connection, Connection.Upgrade);
        respHeaders.put(HeaderName.Upgrade, Upgrade.WebSocket);
        respHeaders.put(HeaderName.Sec_WebSocket_Accept, HeaderValue.from(acceptKey));
        httpResponse.addHeaders(respHeaders);
        return httpResponse;
    }
}
